import argparse
import asyncio
import base64
import json
import os
import re
import sys
import threading
import time
import uuid
from io import BytesIO
from typing import List

import cv2
import numpy as np
import requests
import torch
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
from PIL import Image

# MODIFIED: Changed imports to be consistent with the anonymized project structure.
# This assumes 'serve' is a sub-package of 'my_project'.
from my_project.serve.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG
from my_project.serve.utils import build_logger, pretty_print_semaphore

# NOTE: The following 'RemoteSAM' imports now assume that the RemoteSAM library
# has been properly installed in the environment (e.g., via 'pip install -e .')
# or is accessible via the PYTHONPATH environment variable.
# The hardcoded path manipulations have been removed.
from tasks.code.model import RemoteSAM, init_demo_model
import utils as remotesam_utils


GB = 2 ** 30  # number of bytes in one GiB

# Log file lives under logs/workers/ and is named after this script file.
now_file_name = os.path.basename(__file__)
logdir = "logs/workers/"
os.makedirs(logdir, exist_ok=True)
logfile = os.path.join(logdir, now_file_name + ".log")

# Short identifier for this worker process: first six hex digits of a UUID4.
worker_id = uuid.uuid4().hex[:6]
logger = build_logger(now_file_name, logfile)
global_counter = 0

# Created on first use by acquire_model_semaphore().
model_semaphore = None

# ===============================
# Configuration: whether to output pixel coordinates (directly aligned with GT).
# When False, box coordinates are divided by the processed image width/height
# and rounded to four decimals instead.
PIXEL_OUTPUT = True
# ===============================


def heart_beat_worker(controller):
    """Send heart beats through *controller* forever.

    Intended as a thread target: the loop never returns. Each round first
    sleeps for ``WORKER_HEART_BEAT_INTERVAL`` and then calls
    ``controller.send_heart_beat()``.
    """
    while True:
        # Sleep first, so the initial heart beat goes out one interval after start.
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()


class ModelWorker:
    """RemoteSAM worker: loads the model, talks to the controller, runs grounding.

    The instance keeps ``orig_image_size`` (height, width) of the most recently
    loaded image; it is overwritten on every ``load_image`` call.
    """

    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        no_register: bool,
        checkpoint_path: str,
        device: str,
        use_epoc: bool = False,
    ):
        """Load the RemoteSAM checkpoint and, unless disabled, register the worker.

        Args:
            controller_addr: Base URL of the controller.
            worker_addr: URL under which this worker is announced to the controller.
            worker_id: Identifier stored on the instance.
            no_register: When true, skip registration and the heart-beat thread.
            checkpoint_path: Checkpoint passed to ``init_demo_model``.
            device: Device string passed to ``init_demo_model`` and ``RemoteSAM``.
            use_epoc: Forwarded to ``RemoteSAM`` as ``use_EPOC``.
        """
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        self.device = device
        self.orig_image_size = None  # Used to record original image size

        logger.info(f"Loading RemoteSAM checkpoint from {checkpoint_path} on {device} ...")

        # Backup and clear sys.argv to prevent interference with init_demo_model
        _argv_backup = sys.argv.copy()
        sys.argv = [_argv_backup[0]]
        try:
            base_model = init_demo_model(checkpoint_path, device)
        finally:
            sys.argv = _argv_backup  # Restore sys.argv

        self.model = RemoteSAM(base_model, device, use_EPOC=use_epoc)
        logger.info("RemoteSAM model loaded successfully")

        if not no_register:
            self.register_to_controller()
            # The heart-beat loop never returns; a daemon thread keeps it from
            # blocking interpreter shutdown once the server stops.
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,), daemon=True
            )
            self.heart_beat_thread.start()

    # ---------------------- Controller -----------------------------
    def register_to_controller(self):
        """POST this worker's address and status to the controller.

        Raises:
            RuntimeError: If the controller does not answer with HTTP 200.
        """
        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status(),
        }
        r = requests.post(url, json=data)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if r.status_code != 200:
            raise RuntimeError(
                f"Failed to register to controller at {url}: HTTP {r.status_code}"
            )
        logger.info("Registered to controller")

    def send_heart_beat(self):
        """Report the queue length to the controller, retrying until it answers.

        If the controller replies that it does not know this worker, the worker
        registers again.
        """
        url = self.controller_addr + "/receive_heart_beat"
        while True:
            try:
                ret = requests.post(url, json={"worker_name": self.worker_addr, "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except (requests.exceptions.RequestException, KeyError, TypeError, ValueError) as e:
                # A malformed reply is retried like a network failure so that it
                # cannot terminate the heart-beat thread.
                logger.error(f"heart beat error: {e}")
                time.sleep(5)
        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        """Return running plus waiting requests, derived from the model semaphore.

        Returns 0 while the semaphore has not been created yet. Reads the
        private ``_value`` / ``_waiters`` attributes of the semaphore.
        """
        if model_semaphore is None or not hasattr(model_semaphore, '_value') or not hasattr(model_semaphore, '_waiters'):
            return 0
        # On Python 3.11+ `_waiters` stays None until the first waiter queues up.
        waiters = model_semaphore._waiters
        waiting = len(waiters) if waiters is not None else 0
        return args.limit_model_concurrency - model_semaphore._value + waiting

    def get_status(self):
        """Return the status dict sent to the controller (model names, speed, queue length)."""
        # Register as tool name to controller for TOOL_SCHEMA calls
        return {"model_names": ["remote_sam"], "speed": 1, "queue_length": self.get_queue_length()}

    # ---------------------- Utility -----------------------------
    def _record_image_size(self, path):
        """Store (height, width) of the image at *path*; None when it cannot be opened."""
        try:
            with Image.open(path) as img:
                self.orig_image_size = (img.height, img.width)  # (h, w)
        except Exception:
            # Best effort: the caller falls back to the decoded array's shape.
            self.orig_image_size = None

    @staticmethod
    def _discard_temp_file(path):
        """Delete a temporary image file, logging (not raising) on failure."""
        try:
            os.remove(path)
        except OSError as e:
            logger.warning(f"Failed to remove temporary image {path}: {e}")

    def load_image(self, image_str: str) -> str:
        """Accept file path or base64 string, return local path.

        An existing path is returned unchanged. Anything else is treated as
        base64 (optionally with a ``data:image/...,`` prefix), decoded and
        written to ``temp_images/<uuid>.jpg``. ``self.orig_image_size`` is
        updated in both cases.

        NOTE(review): any readable path supplied by the client is accepted;
        confirm that callers are trusted.

        Raises:
            ValueError: If the input cannot be decoded and stored.
        """
        if os.path.exists(image_str):
            self._record_image_size(image_str)
            return image_str

        try:
            if image_str.startswith('data:image/'):
                image_str = image_str.split(',', 1)[1]
            img_data = base64.b64decode(image_str)
            temp_dir = "temp_images"
            os.makedirs(temp_dir, exist_ok=True)
            temp_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
            with open(temp_path, 'wb') as f:
                f.write(img_data)

            self._record_image_size(temp_path)

            logger.info(f"Base64 image decoded and saved as temporary file: {temp_path}")
            return temp_path
        except Exception as e:
            logger.error(f"Failed to decode image: {e}")
            raise ValueError("Invalid image input") from e

    # ---------------------- Inference -----------------------------
    @staticmethod
    def _extract_caption(params):
        """Return the query text from ``params["caption"]`` or, failing that, ``params["prompt"]``.

        A prompt of the form ``USER: ... <image> TEXT ASSISTANT:`` yields TEXT;
        otherwise the role markers and ``<image>`` are stripped from the prompt.
        Missing or non-string values count as empty.
        """
        caption = params.get("caption")
        caption = caption.strip() if isinstance(caption, str) else ""
        if caption:
            return caption

        prompt = params.get("prompt")
        prompt = prompt.strip() if isinstance(prompt, str) else ""
        if not prompt:
            return ""

        user_match = re.search(r"USER:.*?<image>\s*(.*?)\s*ASSISTANT:", prompt, re.DOTALL)
        if user_match:
            return user_match.group(1).strip()
        return prompt.replace("USER:", "").replace("ASSISTANT:", "").replace("<image>", "").strip()

    @staticmethod
    def _has_items(value):
        """Truthiness that also works for array-like values (uses ``len`` when available)."""
        if value is None:
            return False
        try:
            return len(value) > 0
        except TypeError:
            return bool(value)

    @staticmethod
    def _to_corners(box_item, xywh):
        """Convert a 4-value box to ``(x1, y1, x2, y2)`` floats.

        With ``xywh`` true the input is read as (x, y, width, height);
        otherwise it is read as corner coordinates already.
        """
        if xywh:
            x, y, w_box, h_box = map(float, box_item)
            return x, y, x + w_box, y + h_box
        x1, y1, x2, y2 = map(float, box_item[:4])
        return x1, y1, x2, y2

    def generate_stream_func(self, params):
        """Run grounding for one request and return boxes plus phrases.

        Args:
            params: Request dict. Reads ``caption`` or ``prompt`` for the text
                and ``image`` or the first entry of ``images`` for the picture.

        Returns:
            On success a dict with ``boxes``, ``phrases``, ``size``,
            ``orig_size`` and ``caption_used``; on failure a dict with empty
            ``boxes``/``phrases``, ``size`` ``[0, 0]`` and an ``error`` message.
        """
        caption = self._extract_caption(params)
        logger.info(f"Extracted caption: '{caption}'")

        if not caption:
            return {"boxes": [], "phrases": [], "size": [0, 0], "error": "Please provide a text description."}

        images = params.get("images")
        image_input = params.get("image") or (images[0] if isinstance(images, list) and images else None)

        if not image_input:
            return {"boxes": [], "phrases": [], "size": [0, 0], "error": "No image provided."}

        temp_path = None
        try:
            img_path = self.load_image(image_input)
            # load_image returns an existing path unchanged, so a different
            # path means a temporary file was written for this request.
            if img_path != image_input:
                temp_path = img_path

            img_bgr = cv2.imread(img_path)
            if img_bgr is None:
                raise ValueError("failed to read image")
            img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
            h, w, _ = img_rgb.shape

            box = self.model.visual_grounding(image=img_rgb, sentence=caption)

            boxes_multi = []
            try:
                mask, prob = self.model.referring_seg(image=img_rgb, sentence=caption, return_prob=True)
                # NOTE: This now relies on 'remotesam_utils' being available in the environment
                boxes_multi = remotesam_utils.M2B(mask, prob, box_type='hbb')
                if len(boxes_multi) <= 1:
                    boxes_multi = []  # Fall back to single box if only one is found
            except Exception as e:
                logger.warning(f"Multi-box detection failed: {e}")
                boxes_multi = []

            orig_h, orig_w = self.orig_image_size if self.orig_image_size else (h, w)
            proc_h, proc_w = h, w

            # Multi-box results are treated as (x1, y1, x2, y2); the single
            # visual_grounding box is treated as (x, y, w, h).
            multi_mode = len(boxes_multi) > 0
            if multi_mode:
                box_list = boxes_multi
            elif self._has_items(box):
                box_list = [box]
            else:
                box_list = []

            base_phrase = caption.rstrip('.').strip()
            numbered = len(box_list) > 1
            boxes_out, phrases = [], []
            for i, box_item in enumerate(box_list):
                if not (self._has_items(box_item) and len(box_item) == 4):
                    continue
                x1, y1, x2, y2 = self._to_corners(box_item, xywh=not multi_mode)
                if PIXEL_OUTPUT:
                    boxes_out.append([x1, y1, x2, y2])
                else:
                    boxes_out.append([round(x1 / w, 4), round(y1 / h, 4), round(x2 / w, 4), round(y2 / h, 4)])
                phrase_suffix = f"-{i+1}" if numbered else ""
                phrases.append(f"{base_phrase}{phrase_suffix}")

            return {
                "boxes": boxes_out, "phrases": phrases, "size": [proc_h, proc_w],
                "orig_size": [orig_h, orig_w], "caption_used": caption
            }
        except Exception as e:
            logger.error(f"RemoteSAM processing error: {e}")
            return {"boxes": [], "phrases": [], "size": [0, 0], "error": f"Processing failed: {str(e)}"}
        finally:
            # The decoded image is only needed for this request; without this
            # the temp_images directory grows with every base64 request.
            if temp_path is not None:
                self._discard_temp_file(temp_path)

    def generate_gate(self, params):
        """Wrap ``generate_stream_func`` into the ``{"text", "error_code"}`` reply format.

        Errors reported by ``generate_stream_func`` are returned as text with
        ``error_code`` 0; unexpected exceptions yield
        ``ErrorCode.INTERNAL_ERROR``.
        """
        try:
            out = self.generate_stream_func(params)
            if "error" in out:
                return {"text": f"RemoteSAM: {out['error']}", "error_code": 0}
            return {"text": out, "error_code": 0}
        except Exception as e:
            logger.error(f"inference error: {e}")
            return {"text": f"{SERVER_ERROR_MSG}\n\n({e})", "error_code": ErrorCode.INTERNAL_ERROR}

# ASGI application object; the worker's HTTP endpoints below are registered on it.
app = FastAPI()

async def acquire_model_semaphore():
    """Wait for a free model slot, creating the shared semaphore on first use.

    The semaphore is sized by ``args.limit_model_concurrency`` and stored in
    the module-level ``model_semaphore``.
    """
    global model_semaphore
    semaphore = model_semaphore
    if semaphore is None:
        # First call: build the semaphore and publish it for later callers.
        semaphore = asyncio.Semaphore(args.limit_model_concurrency)
        model_semaphore = semaphore
    await semaphore.acquire()

def release_model_semaphore():
    """Give back one slot of the module-level model semaphore."""
    semaphore = model_semaphore
    semaphore.release()

@app.post("/worker_generate")
async def api_generate(request: Request):
    """Run one grounding request and return the result as a JSON response.

    The request body is parsed as JSON and passed to ``worker.generate_gate``
    while a model-semaphore slot is held. The call is a plain synchronous
    call inside this coroutine.
    """
    params = await request.json()
    await acquire_model_semaphore()
    try:
        output = worker.generate_gate(params)
    finally:
        # Always hand the slot back, even if inference raises unexpectedly;
        # otherwise the slot would be lost for the lifetime of the process.
        release_model_semaphore()
    return JSONResponse(output)

@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    """Run one grounding request and stream the result as a single chunk.

    The response body is the JSON-encoded output of ``worker.generate_gate``
    followed by a NUL byte, sent as ``application/octet-stream``.
    """
    params = await request.json()
    await acquire_model_semaphore()
    try:
        output = worker.generate_gate(params)
        # Serialize while still inside the handler so a failure surfaces here
        # rather than in the middle of the stream.
        payload = json.dumps(output).encode() + b"\0"
    finally:
        # All model work is finished at this point, so the slot is released
        # immediately instead of depending on a post-stream background task.
        release_model_semaphore()

    def generate():
        yield payload

    return StreamingResponse(generate(), media_type="application/octet-stream")

@app.post("/worker_get_status")
async def api_get_status(request: Request):
    """Return the worker's current status dict; the request body is not read."""
    status = worker.get_status()
    return status


if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments), in the
    # order they are registered on the parser.
    _cli_options = (
        ("--host", dict(type=str, default="localhost")),
        ("--port", dict(type=int, default=20003)),
        ("--worker-address", dict(type=str, default="http://localhost:20003")),
        ("--controller-address", dict(type=str, default="http://localhost:20001")),
        ("--checkpoint", dict(type=str, required=True, help="Path to RemoteSAM checkpoint .pth")),
        ("--device", dict(type=str, default="cuda:0")),
        ("--limit-model-concurrency", dict(type=int, default=2)),
        ("--no-register", dict(action="store_true")),
        ("--use-epoc", dict(action="store_true", help="Enable EPOC for RemoteSAM")),
    )
    parser = argparse.ArgumentParser()
    for _flag, _kwargs in _cli_options:
        parser.add_argument(_flag, **_kwargs)

    # `args` and `worker` are module-level names read by the request handlers
    # and by ModelWorker.get_queue_length, so they must be bound here.
    args = parser.parse_args()

    _worker_kwargs = {
        "controller_addr": args.controller_address,
        "worker_addr": args.worker_address,
        "worker_id": worker_id,
        "no_register": args.no_register,
        "checkpoint_path": args.checkpoint,
        "device": args.device,
        "use_epoc": args.use_epoc,
    }
    worker = ModelWorker(**_worker_kwargs)

    import uvicorn

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")